import logging
import os
from pathlib import Path

from huggingface_hub import create_repo, upload_file, upload_folder

from simpletuner.helpers.publishing.metadata import save_model_card, save_training_config
from simpletuner.helpers.training.state_tracker import StateTracker

from simpletuner.helpers.training.multi_process import should_log
logger = logging.getLogger(__name__)

# Processes for which should_log() is true honour SIMPLETUNER_LOG_LEVEL (default
# INFO); every other process is restricted to ERROR-level output.
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") if should_log() else "ERROR")


# Artifact file names used when publishing checkpoints to the Hub.
LORA_SAFETENSORS_FILENAME = "pytorch_lora_weights.safetensors"
EMA_SAFETENSORS_FILENAME = "ema_model.safetensors"
SLA_ATTENTION_FILENAME = "sla_attention.pt"


class HubManager:
    def __init__(self, config, model, repo_id: str = None):
        self.config = config
        self.model = model
        self.repo_id = repo_id or self.config.hub_model_id or self.config.tracker_project_name
        self.hub_token = self._load_hub_token()
        self.data_backends = StateTracker.get_data_backends(_types=["image", "video", "audio"])
        self._create_repo()
        self.validation_prompts = None
        self.validation_shortnames = None
        self.collected_data_backend_str = None

    def _create_repo(self):
        if not self.config.push_to_hub:
            return
        self._repo_id = create_repo(
            repo_id=self.config.hub_model_id or self.config.tracker_project_name,
            exist_ok=True,
            private=self.config.model_card_private,
        ).repo_id

    def _repo_url(self, subpath: str | None = None) -> str:
        repo_id = getattr(self, "_repo_id", None) or self.config.hub_model_id or self.config.tracker_project_name
        if not repo_id:
            return ""
        base = f"https://huggingface.co/{repo_id}"
        if subpath:
            clean = subpath.strip("/")
            return f"{base}/tree/main/{clean}"
        return base

    def _vae_string(self):
        if "deepfloyd" in self.config.model_type:
            return "\nDeepFloyd Pixel diffusion (no VAE)."
        else:
            return f"\nVAE: {self.config.pretrained_vae_model_name_or_path}"

    def _commit_message(self):
        return (
            f"Trained for {StateTracker.get_epoch() - 1} epochs and {StateTracker.get_global_step()} steps."
            f"\nTrained with datasets {self.collected_data_backend_str}"
            f"\nLearning rate {self.config.learning_rate}, batch size {self.config.train_batch_size}, and {self.config.gradient_accumulation_steps} gradient accumulation steps."
            f"\nTrained with {self.config.prediction_type} prediction type and rescaled_betas_zero_snr={self.config.rescale_betas_zero_snr}"
            f"\nUsing '{self.config.training_scheduler_timestep_spacing}' timestep spacing."
            f"\nBase model: {self.config.pretrained_model_name_or_path}"
            f"{self._vae_string()}"
        )

    def _load_hub_token(self):
        if not self.config.push_to_hub:
            return None
        token_path = os.path.join(os.path.expanduser("~"), ".cache/huggingface/token")
        if os.path.exists(token_path):
            with open(token_path, "r") as f:
                return f.read().strip()
        raise ValueError(
            f"No Hugging Face Hub token found ({token_path}). Please ensure you have logged in with 'huggingface-cli login'."
        )

    def set_validation_prompts(self, validation_prompts):
        self.validation_prompts = validation_prompts.get("validation_prompts", [])
        self.validation_shortnames = validation_prompts.get("validation_shortnames", [])

    def upload_validation_folder(self, webhook_handler=None, override_path=None):
        if webhook_handler:
            webhook_handler.send(
                message=f"Uploading {'model' if override_path is None else 'intermediary checkpoint'} validation samples to Hugging Face Hub as `{self.repo_id}`."
            )
            webhook_handler.send_raw(
                structured_data={"status": "uploading_validation_samples"},
                message_type="training.status",
                message_level="info",
                job_id=StateTracker.get_job_id(),
            )
        if not self.config.push_to_hub:
            return
        try:
            upload_folder(
                repo_id=self._repo_id,
                folder_path=os.path.join(override_path or self.config.output_dir, "assets"),
                path_in_repo="assets/",
                commit_message="Validation images auto-generated by SimpleTuner",
            )
        except Exception as e:
            logger.error(f"Error uploading validation images to Hugging Face Hub: {e}")

    def upload_model(self, validation_images, webhook_handler=None, override_path=None):
        repo_folder = override_path or os.path.join(
            self.config.output_dir,
            "pipeline" if "lora" not in self.config.model_type else "",
        )
        save_training_config(repo_folder=repo_folder, config=self.config)
        save_model_card(
            model=self.model,
            repo_id=self.repo_id,
            images=validation_images,
            base_model=self.config.pretrained_model_name_or_path,
            train_text_encoder=self.config.train_text_encoder,
            prompt=self.config.validation_prompt,
            validation_prompts=self.validation_prompts,
            validation_shortnames=self.validation_shortnames,
            repo_folder=repo_folder,
        )
        if not self.config.push_to_hub:
            return
        if webhook_handler:
            webhook_handler.send(
                message=f"Uploading {'model' if override_path is None else 'intermediary checkpoint'} to Hugging Face Hub as `{self.repo_id}`."
            )
            webhook_handler.send_raw(
                structured_data={"status": "uploading_model"},
                message_type="training.status",
                message_level="info",
                job_id=StateTracker.get_job_id(),
            )

        try:
            self.upload_validation_folder(webhook_handler=webhook_handler, override_path=override_path)
        except:
            logger.error("Error uploading validation images to Hugging Face Hub.")

        attempt = 0
        while attempt < 3:
            attempt += 1
            try:
                if "lora" not in self.config.model_type:
                    self.upload_full_model(override_path=override_path)
                else:
                    self.upload_lora_model(override_path=override_path)
                    if self.config.use_ema:
                        self.upload_ema_model(override_path=override_path)
                break
            except Exception as e:
                if webhook_handler:
                    webhook_handler.send(
                        message=f"(attempt {attempt}/3) Error uploading model to Hugging Face Hub: {e}. Retrying..."
                    )
                    webhook_handler.send_raw(
                        structured_data={"status": "uploading_model"},
                        message_type="training.status",
                        message_level="info",
                        job_id=StateTracker.get_job_id(),
                    )
        repo_url = self._repo_url()
        if webhook_handler:
            webhook_handler.send(message=f"Model is now available [on Hugging Face Hub]({repo_url}).")
            webhook_handler.send_raw(
                structured_data={"status": "model_available"},
                message_type="training.status",
                message_level="info",
                job_id=StateTracker.get_job_id(),
            )
        return repo_url

    def upload_full_model(self, override_path=None):
        if not self.config.push_to_hub:
            return
        folder_path = os.path.join(self.config.output_dir, "pipeline")
        try:
            upload_folder(
                repo_id=self._repo_id,
                folder_path=override_path or folder_path,
                commit_message=self._commit_message(),
            )
        except Exception as e:
            logger.error(f"Failed to upload pipeline to hub: {e}")

    def upload_lora_model(self, override_path=None):
        if not self.config.push_to_hub:
            return
        checkpoint_root = override_path or self.config.output_dir
        lora_weights_path = os.path.join(checkpoint_root, LORA_SAFETENSORS_FILENAME)
        if not os.path.exists(lora_weights_path):
            raise FileNotFoundError(f"Missing required artifact: {lora_weights_path}")
        sla_path = os.path.join(checkpoint_root, SLA_ATTENTION_FILENAME)
        try:
            upload_file(
                repo_id=self._repo_id,
                path_in_repo=f"/{LORA_SAFETENSORS_FILENAME}",
                path_or_fileobj=lora_weights_path,
                commit_message=self._commit_message(),
            )
            if os.path.exists(sla_path):
                upload_file(
                    repo_id=self._repo_id,
                    path_in_repo=f"/{SLA_ATTENTION_FILENAME}",
                    path_or_fileobj=sla_path,
                    commit_message="SLA attention state auto-generated by SimpleTuner",
                )
            else:
                logger.debug("SLA attention state not found at %s; skipping upload.", sla_path)
            readme_path = os.path.join(checkpoint_root, "README.md")
            upload_file(
                repo_id=self._repo_id,
                path_in_repo="/README.md",
                path_or_fileobj=readme_path,
                commit_message="Model card auto-generated by SimpleTuner",
            )
        except Exception as e:
            logger.error(f"Failed to upload LoRA artifacts to hub: {e}")
            raise

    def upload_ema_model(self, override_path=None):
        if not self.config.push_to_hub or not self.config.use_ema:
            return
        try:
            check_ema_paths = ["transformer_ema", "unet_ema", "controlnet_ema", "ema"]
            # if any of the folder names are present in the checkpoint dir, we will upload them too
            for check_ema_path in check_ema_paths:
                print(f"Checking for EMA path: {check_ema_path}")
                ema_path = os.path.join(override_path or self.config.output_dir, check_ema_path)
                if os.path.exists(ema_path):
                    print(f"Found EMA checkpoint!")
                    upload_folder(
                        repo_id=self._repo_id,
                        folder_path=ema_path,
                        path_in_repo="/ema",
                        commit_message="LoRA EMA checkpoint auto-generated by SimpleTuner",
                    )
        except Exception as e:
            logger.error(f"Failed to upload LoRA EMA weights to hub: {e}")

    def find_latest_checkpoint(self):
        checkpoints = list(Path(self.config.output_dir).rglob("checkpoint-*"))
        highest_checkpoint_value = None
        highest_checkpoint = None
        if len(checkpoints) > 0:
            highest_checkpoint_value = 0
            for checkpoint in checkpoints:
                # split by -
                parts = checkpoint.stem.split("-")
                checkpoint_value = int(parts[-1])
                if checkpoint_value > highest_checkpoint_value:
                    highest_checkpoint_value = checkpoint_value
                    highest_checkpoint = checkpoint

        return highest_checkpoint

    def upload_latest_checkpoint(self, validation_images: dict, webhook_handler=None):
        checkpoint_path = self.find_latest_checkpoint()
        if checkpoint_path:
            logging.info(f"Checkpoint path: {checkpoint_path}")
            try:
                # Extract step number from checkpoint path (e.g., "checkpoint-50" -> 50)
                checkpoint_step = None
                checkpoint_name = os.path.basename(str(checkpoint_path))
                if "checkpoint-" in checkpoint_name:
                    try:
                        checkpoint_step = int(checkpoint_name.split("-")[1])
                    except (IndexError, ValueError):
                        logger.warning(f"Could not extract step number from checkpoint path: {checkpoint_path}")

                # Filter validation images to only include those for this checkpoint step
                filtered_images = {}
                if validation_images and checkpoint_step is not None:
                    validation_dir = os.path.join(self.config.output_dir, "validation_images")
                    if os.path.exists(validation_dir):
                        # Look for images with step_{checkpoint_step}_ in the filename
                        for shortname, images in validation_images.items():
                            filtered_images[shortname] = []
                            # Get the actual image files for this step
                            step_pattern = f"step_{checkpoint_step}_"
                            for img_file in os.listdir(validation_dir):
                                if step_pattern in img_file and shortname in img_file:
                                    img_path = os.path.join(validation_dir, img_file)
                                    try:
                                        from PIL import Image

                                        img = Image.open(img_path)
                                        filtered_images[shortname].append(img)
                                    except Exception as e:
                                        logger.warning(f"Could not load validation image {img_path}: {e}")
                            # Remove empty entries
                            if not filtered_images[shortname]:
                                del filtered_images[shortname]

                # Only use images that were actually generated at this checkpoint step.
                # Don't fall back to validation_images as those may be from a different step
                # (e.g., benchmark images from step 0) and shouldn't be associated with this checkpoint.
                # If no validation was run at this checkpoint step, we simply don't include validation
                # images in the model card.
                images_to_upload = filtered_images if filtered_images else None

                repo_url = self.upload_model(
                    validation_images=images_to_upload,
                    override_path=checkpoint_path,
                    webhook_handler=webhook_handler,
                )
                remote_path = self._repo_url(checkpoint_path.name)
                return remote_path, str(checkpoint_path), repo_url
            except Exception as e:
                logger.error(f"Failed to upload latest checkpoint: {e}")
                import traceback

                logger.error(traceback.format_exc())
        return None, str(checkpoint_path) if checkpoint_path else None, self._repo_url() if self.config.push_to_hub else None

    def upload_validation_images(self, validation_images, webhook_handler=None, override_path=None):
        logging.info(f"Validation images for upload: {validation_images}")
        if validation_images and len(validation_images) > 0:
            idx = 0
            for shortname, images in validation_images.items() if type(validation_images) is dict else validation_images:
                # print(f"Shortname {shortname} images: {images}")
                if type(images) is not list:
                    images = [images]
                sub_idx = 0
                for image in images:
                    image_path = os.path.join(
                        override_path or self.config.output_dir,
                        "assets",
                        f"image_{idx}_{sub_idx}.png",
                    )
                    image.save(image_path, format="PNG")
                    if not self.config.push_to_hub:
                        continue
                    attempt = 0
                    while attempt < 3:
                        attempt += 1
                        try:
                            upload_file(
                                repo_id=self._repo_id,
                                path_in_repo=f"/assets/image_{idx}_{sub_idx}.png",
                                path_or_fileobj=image_path,
                                commit_message="Validation image auto-generated by SimpleTuner",
                            )
                        except Exception as e:
                            if webhook_handler:
                                webhook_handler.send(
                                    message=f"(attempt {attempt}/3) Error uploading validation image to Hugging Face Hub: {e}. Retrying..."
                                )
                                webhook_handler.send_raw(
                                    structured_data={"status": "uploading_validation_samples"},
                                    message_type="training.status",
                                    message_level="info",
                                    job_id=StateTracker.get_job_id(),
                                )
                    sub_idx += 1
                    idx += 1
